{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Learning Tree-augmented Naive Bayes (TAN) Structure from Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we show an example of learning the structure of a Bayesian network with the TAN algorithm. We first build a model and sample data from it, then attempt to recover the model's graph structure from the sampled data.\n", "\n", "For a comparison of the Naive Bayes and TAN classifiers, refer to the blog post [Classification with TAN and Pgmpy](https://loudly-soft.blogspot.com/2020/08/classification-with-tree-augmented.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## First, create a Naive Bayes graph" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import networkx as nx\n", "import matplotlib.pyplot as plt\n", "from pgmpy.models import BayesianNetwork\n", "\n", "# class variable is A and feature variables are B, C, D, E and R\n", "model = BayesianNetwork([(\"A\", \"R\"), (\"A\", \"B\"), (\"A\", \"C\"), (\"A\", \"D\"), (\"A\", \"E\")])\n", "nx.draw_circular(\n", " model, with_labels=True, arrowsize=30, node_size=800, alpha=0.3, font_weight=\"bold\"\n", ")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Second, add interaction between the features" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# feature R correlates with other features\n", "model.add_edges_from([(\"R\", \"B\"), (\"R\", \"C\"), (\"R\", \"D\"), (\"R\", \"E\")])\n", "nx.draw_circular(\n", " model, with_labels=True, arrowsize=30, node_size=800, alpha=0.3, font_weight=\"bold\"\n", ")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Then, parameterize our graph to create a Bayesian network" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from pgmpy.factors.discrete import TabularCPD\n", "\n", "# define a CPD for each node; each column must sum to 1\n", "cpd_a = TabularCPD(\"A\", 2, [[0.7], [0.3]])\n", "cpd_r = TabularCPD(\n", " \"R\", 3, [[0.6, 0.2], [0.3, 0.5], [0.1, 0.3]], evidence=[\"A\"], evidence_card=[2]\n", ")\n", "cpd_b = TabularCPD(\n", " \"B\",\n", " 3,\n", " [\n", " [0.1, 0.1, 0.2, 0.2, 0.7, 0.1],\n", " [0.1, 0.3, 0.1, 0.2, 0.1, 0.2],\n", " [0.8, 0.6, 0.7, 0.6, 0.2, 0.7],\n", " ],\n", " evidence=[\"A\", \"R\"],\n", " evidence_card=[2, 3],\n", ")\n", "cpd_c = TabularCPD(\n", " \"C\",\n", " 2,\n", " [[0.7, 0.2, 0.2, 0.5, 0.1, 0.3], [0.3, 0.8, 0.8, 0.5, 0.9, 0.7]],\n", " evidence=[\"A\", \"R\"],\n", " evidence_card=[2, 3],\n", ")\n", "cpd_d = TabularCPD(\n", " \"D\",\n", " 3,\n", " [\n", " [0.3, 0.8, 0.2, 0.8, 0.4, 0.7],\n", " [0.4, 0.1, 0.4, 0.1, 0.1, 0.1],\n", " [0.3, 0.1, 0.4, 0.1, 0.5, 0.2],\n", " ],\n", " evidence=[\"A\", \"R\"],\n", " evidence_card=[2, 3],\n", ")\n", "cpd_e = TabularCPD(\n", " \"E\",\n", " 2,\n", " [[0.5, 0.6, 0.6, 0.5, 0.5, 0.4], [0.5, 0.4, 0.4, 0.5, 0.5, 0.6]],\n", " evidence=[\"A\", \"R\"],\n", " evidence_card=[2, 3],\n", ")\n", "model.add_cpds(cpd_a, cpd_r, cpd_b, cpd_c, cpd_d, cpd_e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Next, generate sample data from our Bayesian network" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Generating 
for node: B: 100%|██████████| 6/6 [00:00<00:00, 192.64it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ " A R B C D E\n", "0 0 1 2 1 0 0\n", "1 1 2 2 1 2 0\n", "2 0 0 2 0 2 1\n", "3 1 2 2 1 0 1\n", "4 1 2 2 1 0 1\n", "... .. .. .. .. .. ..\n", "9995 0 0 2 0 1 1\n", "9996 1 2 1 0 0 1\n", "9997 0 0 2 0 1 0\n", "9998 1 0 2 1 0 1\n", "9999 0 0 2 0 2 1\n", "\n", "[10000 rows x 6 columns]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "from pgmpy.sampling import BayesianModelSampling\n", "\n", "# sample data from BN\n", "inference = BayesianModelSampling(model)\n", "df_data = inference.forward_sample(size=10000)\n", "print(df_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Now we are ready to learn the TAN structure from sample data" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Building tree: 100%|██████████| 15/15.0 [00:00<00:00, 5215.93it/s]\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from pgmpy.estimators import TreeSearch\n", "\n", "# learn graph structure\n", "est = TreeSearch(df_data, root_node=\"R\")\n", "dag = est.estimate(estimator_type=\"tan\", class_node=\"A\")\n", "nx.draw_circular(\n", " dag, with_labels=True, arrowsize=30, node_size=800, alpha=0.3, font_weight=\"bold\"\n", ")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Finally, parameterize the learned graph from the data (see the other tutorials for more options)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[,\n", " ,\n", " ,\n", " ,\n", " ,\n", " ]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pgmpy.estimators import BayesianEstimator\n", "\n", "# there are many ways to parameterize the model; here is one example\n", "model = BayesianNetwork(dag.edges())\n", "model.fit(\n", " df_data, estimator=BayesianEstimator, prior_type=\"dirichlet\", pseudo_counts=0.1\n", ")\n", "model.get_cpds()" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 1 }